{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rYzDXA9vJfz1"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gym\n",
    "from stable_baselines3.common.env_checker import check_env\n",
    "\n",
    "\n",
    "class GoLeftEnv(gym.Env):\n",
    "    #支持的render模式,在jupyter中不支持human模式\n",
    "    metadata = {'render.modes': ['console']}\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #初始位置\n",
    "        self.pos = 9\n",
    "\n",
    "        #动作空间,这个环境中只有左,右两个动作\n",
    "        self.action_space = gym.spaces.Discrete(2)\n",
    "\n",
    "        #状态空间,一维数轴\n",
    "        self.observation_space = gym.spaces.Box(low=0,\n",
    "                                                high=10,\n",
    "                                                shape=(1, ),\n",
    "                                                dtype=np.float32)\n",
    "\n",
    "    def reset(self):\n",
    "        #重置位置\n",
    "        self.pos = 9\n",
    "\n",
    "        #当前状态\n",
    "        return np.array([self.pos], dtype=np.float32)\n",
    "\n",
    "    def step(self, action):\n",
    "        #执行动作\n",
    "        if action == 0:\n",
    "            self.pos -= 1\n",
    "\n",
    "        if action == 1:\n",
    "            self.pos += 1\n",
    "\n",
    "        self.pos = np.clip(self.pos, 0, 10)\n",
    "\n",
    "        #判断游戏结束\n",
    "        done = self.pos == 0\n",
    "\n",
    "        #给予reward\n",
    "        reward = 1 if self.pos == 0 else 0\n",
    "\n",
    "        return np.array([self.pos], dtype=np.float32), reward, bool(done), {}\n",
    "\n",
    "    def render(self, mode='console'):\n",
    "        if mode != 'console':\n",
    "            raise NotImplementedError()\n",
    "        print(self.pos)\n",
    "\n",
    "    def close(self):\n",
    "        pass\n",
    "\n",
    "\n",
    "env = GoLeftEnv()\n",
    "\n",
    "#检查环境是否合法\n",
    "check_env(env, warn=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "\n",
    "#包装环境\n",
    "train_env = make_vec_env(lambda: env, n_envs=1)\n",
    "\n",
    "#定义模型\n",
    "model = PPO('MlpPolicy', train_env, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [9.] 1 0\n",
      "1 [10.] 0 0\n",
      "2 [9.] 0 0\n",
      "3 [8.] 0 0\n",
      "4 [7.] 0 0\n",
      "5 [6.] 1 0\n",
      "6 [7.] 1 0\n",
      "7 [8.] 0 0\n",
      "8 [7.] 0 0\n",
      "9 [6.] 0 0\n",
      "10 [5.] 1 0\n",
      "11 [6.] 0 0\n",
      "12 [5.] 0 0\n",
      "13 [4.] 1 0\n",
      "14 [5.] 0 0\n",
      "15 [4.] 1 0\n",
      "16 [5.] 0 0\n",
      "17 [4.] 1 0\n",
      "18 [5.] 1 0\n",
      "19 [6.] 1 0\n",
      "20 [7.] 0 0\n",
      "21 [6.] 0 0\n",
      "22 [5.] 1 0\n",
      "23 [6.] 1 0\n",
      "24 [7.] 0 0\n",
      "25 [6.] 0 0\n",
      "26 [5.] 1 0\n",
      "27 [6.] 0 0\n",
      "28 [5.] 0 0\n",
      "29 [4.] 0 0\n",
      "30 [3.] 1 0\n",
      "31 [4.] 0 0\n",
      "32 [3.] 0 0\n",
      "33 [2.] 1 0\n",
      "34 [3.] 0 0\n",
      "35 [2.] 0 0\n",
      "36 [1.] 0 1\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#测试一个环境\n",
    "def test(model, env):\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    step = 0\n",
    "\n",
    "    for i in range(100):\n",
    "        action = model.predict(state)[0]\n",
    "\n",
    "        next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "        if step % 1 == 0:\n",
    "            print(step, state, action, reward)\n",
    "\n",
    "        state = next_state\n",
    "        step += 1\n",
    "\n",
    "        if over:\n",
    "            break\n",
    "\n",
    "\n",
    "test(model, env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PQfLBE28SNDr"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [9.] 0 0\n",
      "1 [8.] 0 0\n",
      "2 [7.] 0 0\n",
      "3 [6.] 0 0\n",
      "4 [5.] 0 0\n",
      "5 [4.] 0 0\n",
      "6 [3.] 0 0\n",
      "7 [2.] 0 0\n",
      "8 [1.] 0 1\n"
     ]
    }
   ],
   "source": [
    "model.learn(5000)\n",
    "\n",
    "#测试\n",
    "test(model, env)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "5.custom_gym_env.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
